import torch
import spconv.pytorch as spconv

try:
    import ocnn
except ImportError:
    ocnn = None
from addict import Dict
from typing import List

from pointcept.models.utils.serialization import encode
from pointcept.models.utils import (
    offset2batch,
    batch2offset,
    offset2bincount,
    bincount2offset,
)


class Point(Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary holding the properties
    of a batched point cloud. Properties with the following names carry a
    specific meaning:

    - "coord": original coordinates of the point cloud;
    - "grid_coord": grid coordinates for a specific grid size (related to GridSampling);
    Optional attributes:
    - "offset" / "batch": when exactly one of them is present, the other is
      derived from it at construction time;
    - "feat": point features, the default model input;
    - "grid_size": grid size of the point cloud (related to GridSampling);
    (related to Serialization)
    - "serialized_depth": serialization depth; 2 ** depth * grid_size bounds the point cloud range;
    - "serialized_code": a list of serialization codes;
    - "serialized_order": a list of serialization orders determined by the codes;
    - "serialized_inverse": a list of inverse mappings determined by the codes;
    (related to Sparsify: SpConv)
    - "sparse_shape": sparse shape for the Sparse Conv Tensor;
    - "sparse_conv_feat": SparseConvTensor initialized with information provided by the Point;
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Derive whichever of "batch"/"offset" is missing from the other;
        # when both (or neither) are present, leave the dict untouched.
        has_batch = "batch" in self.keys()
        has_offset = "offset" in self.keys()
        if has_offset and not has_batch:
            self["batch"] = offset2batch(self.offset)
        elif has_batch and not has_offset:
            self["offset"] = batch2offset(self.batch)

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        Compute serialization codes, per-code sort orders, and their inverse
        permutations, and cache them on the Point.

        relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        Args:
            order: iterable of order names forwarded to ``encode`` (e.g. "z",
                "hilbert"). If it contains "nps", per-grid-cell centroids are
                generated first and used for the "nps"/"random" encodings.
                NOTE(review): a plain string is iterated per character here —
                "z" only works because it is one character; confirm callers
                pass a list/tuple for multi-character order names.
            depth: serialization cube depth (cube side = 2 ** depth); derived
                from the maximum grid coordinate when None.
            shuffle_orders: randomly permute the stacked code/order rows.
        """
        assert "batch" in self.keys()
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what your want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()

        if depth is None:
            # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        if "nps" in order:
            # "nps"/"random" encode per-cell centroid coordinates ("new_coord")
            # instead of integer grid coordinates, so build them first.
            self._generate_new_grid_coord()
            code = [
                encode(self.new_coord, self.batch, order=order_) if order_ == "nps" or order_ == "random" else
                encode(self.grid_coord, self.batch, depth, order=order_)
                for order_ in order
            ]
        else:
            code = [
                encode(self.grid_coord, self.batch, depth, order=order_) for order_ in order
            ]
        code = torch.stack(code)
        order = torch.argsort(code)
        # Per-row inverse permutation: inverse[k, order[k, i]] = i.
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = order
        self["serialized_inverse"] = inverse

    def count_unique_coords_hash(self, coords):
        """
        Count unique coordinate rows via ``torch.unique``.

        The previous implementation packed each row into one integer using
        powers of ``max(coords) + 1``; that can overflow int64 for large
        coordinate values (the scale factor is squared) and can collide for
        negative coordinates. Row-wise ``torch.unique(dim=0)`` avoids both
        and supports any coordinate dimension, not just 2D/3D.

        Args:
            coords: (n, d) coordinate tensor.

        Returns:
            (total, unique, duplicates): number of rows, number of unique
            rows, and their difference. (0, 0, 0) for an empty input.
        """
        if coords.numel() == 0:
            return 0, 0, 0

        total = len(coords)
        # Unique rows, independent of value range or sign.
        unique_count = torch.unique(coords, dim=0).shape[0]
        return total, unique_count, total - unique_count
    def sparsify(self, pad=96):
        """
        Prepare a ``spconv.SparseConvTensor`` from this Point.

        Point clouds are inherently sparse; "sparsify" here specifically means
        building the SparseConvTensor consumed by SpConv, cached as
        "sparse_conv_feat" (its spatial shape cached as "sparse_shape").

        relay on ["grid_coord" or "coord" + "grid_size", "batch", "feat"]

        Args:
            pad: padding added to the max grid coordinate when the sparse
                shape must be derived from the data.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        if "grid_coord" not in self.keys():
            # Without GridSampling in the data pipeline, derive integer grid
            # coordinates from "coord" + "grid_size"; add e.g.
            # dict(type="Copy", keys_dict={"grid_size": 0.01}) to supply the
            # grid size (adjust to your needs).
            assert {"grid_size", "coord"}.issubset(self.keys())
            shifted = self.coord - self.coord.min(0)[0]
            self["grid_coord"] = torch.div(
                shifted, self.grid_size, rounding_mode="trunc"
            ).int()
        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            # Derive the spatial extent from the data, padded for conv margins.
            max_coord = torch.max(self.grid_coord, dim=0).values
            sparse_shape = torch.add(max_coord, pad).tolist()
        batch_count = self.batch[-1].tolist() + 1
        # SpConv indices layout: (batch_idx, x, y, z) per point.
        indices = torch.cat(
            [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
        ).contiguous()
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = spconv.SparseConvTensor(
            features=self.feat,
            indices=indices,
            spatial_shape=sparse_shape,
            batch_size=batch_count,
        )

    def _generate_new_grid_coord(self):
        """
        Compute the centroid of the raw points inside each grid cell and
        assign it to every point as its new coordinate.

        Relies on ["grid_coord", "batch", "coord"].

        Side effects (cached on the Point):
            - "new_coord": (n, 3) float32 centroid of each point's cell;
            - "grid_counts": number of points per unique (batch, cell);
            - "grid_inverse_map": point index -> unique-cell index;
            - "grid_unique_id": packed int64 id of each unique cell.
        """
        grid_coord = self["grid_coord"]
        batch = self["batch"]
        device = grid_coord.device

        assert len(batch) == len(grid_coord), "Batch size mismatch with grid coordinates"

        # Pack (batch, x, y, z) into one int64 key so identical grid cells in
        # different batch elements stay distinct.
        # Fix: cast to int64 first — "grid_coord" is created with .int()
        # (int32) elsewhere in this class, and multiplying an int32 tensor by
        # 10**12 makes PyTorch raise an overflow error on the scalar.
        gc = grid_coord.long()
        batch_factor = 10**18
        unique_id = batch.long() * batch_factor + \
                    gc[:, 0] * 10**12 + \
                    gc[:, 1] * 10**6 + \
                    gc[:, 2]
        # NOTE(review): the packing assumes each grid coordinate < 10**6;
        # this holds under the depth <= 16 serialization limit (coords < 2**16).

        # Unique cell ids and point -> cell mapping.
        unique_grid_id, inverse_indices = torch.unique(unique_id, return_inverse=True)

        # Points per unique cell.
        grid_counts = torch.bincount(inverse_indices)

        # Accumulate raw coordinates per cell, then average.
        num_grids = len(unique_grid_id)
        sum_coords = torch.zeros((num_grids, 3), dtype=torch.float32, device=device)
        sum_coords.index_add_(0, inverse_indices, self["coord"].float())
        new_coords = sum_coords / grid_counts.unsqueeze(1)

        # Broadcast each cell centroid back to its member points, preserving
        # the original point order.
        self["new_coord"] = new_coords[inverse_indices]

        # Keep the per-cell bookkeeping for downstream consumers.
        self["grid_counts"] = grid_counts
        self["grid_inverse_map"] = inverse_indices
        self["grid_unique_id"] = unique_grid_id

    def octreelization(self, depth=None, full_depth=None):
        """
        Point Cloud Octreelization

        Generate octree with OCNN
        relay on ["grid_coord", "batch", "feat"]

        Args:
            depth: octree depth; taken from self["depth"] if present,
                otherwise derived from the maximum grid coordinate.
            full_depth: depth up to which the octree is fully subdivided
                (defaults to 1).
        """
        assert (
            ocnn is not None
        ), "Please follow https://github.com/octree-nn/ocnn-pytorch install ocnn."
        assert {"feat", "batch"}.issubset(self.keys())
        # add 1 to make grid space support shift order
        if "grid_coord" not in self.keys():
            # if you don't want to operate GridSampling in data augmentation,
            # please add the following augmentation into your pipline:
            # dict(type="Copy", keys_dict={"grid_size": 0.01}),
            # (adjust `grid_size` to what your want)
            assert {"grid_size", "coord"}.issubset(self.keys())
            self["grid_coord"] = torch.div(
                self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
            ).int()
        if depth is None:
            if "depth" in self.keys():
                depth = self.depth
            else:
                depth = int(self.grid_coord.max() + 1).bit_length()
        if full_depth is None:
            full_depth = 1
        self["depth"] = depth
        assert depth <= 16  # maximum in ocnn

        # [0, 2**depth] -> [0, 2] -> [-1, 1]
        coord = self.grid_coord / 2 ** (self.depth - 1) - 1.0
        point = ocnn.octree.Points(
            points=coord,
            features=self.feat,
            batch_id=self.batch.unsqueeze(-1),
            batch_size=self.batch[-1] + 1,
        )
        octree = ocnn.octree.Octree(
            depth=depth,
            full_depth=full_depth,
            batch_size=self.batch[-1] + 1,
            device=coord.device,
        )
        octree.build_octree(point)
        octree.construct_all_neigh()

        # Map every input point to its octree node index at `depth`.
        query_pts = torch.cat([self.grid_coord, point.batch_id], dim=1).contiguous()
        inverse = octree.search_xyzb(query_pts, depth, True)
        assert torch.sum(inverse < 0) == 0  # all mapping should be valid
        # NOTE(review): `order` is allocated with the number of *unique* node
        # ids (inverse_) while index=inverse has one entry per point; if
        # `inverse` contains duplicates or ids >= len(inverse_), this
        # scatter_ may error or silently drop entries. Upstream Pointcept
        # uses torch.zeros_like(inverse) here — confirm this deviation is
        # intentional.
        inverse_ = torch.unique(inverse)
        order = torch.zeros_like(inverse_).scatter_(
            dim=0,
            index=inverse,
            src=torch.arange(0, inverse.shape[0], device=inverse.device),
        )
        self["octree"] = octree
        self["octree_order"] = order
        self["octree_inverse"] = inverse